{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 基于 GCN 的有监督学习\n",
    "\n",
    "图神经网络（GNN）结合了图结构和机器学习的优势. GraphScope提供了处理学习任务的功能。本次教程，我们将会展示GraphScope如何使用GCN算法训练一个模型。\n",
    "\n",
    "本次教程的学习任务是在文献引用网络上的点分类任务。在点分类任务中，算法会确定[Cora](https://linqs.soe.ucsc.edu/data)数据集上每个顶点的标签。在```Cora```数据集中，由学术出版物作为顶点，出版物之间的引用作为边，如果出版物A引用了出版物B，则图中会存在一条从A到B的边。Cora数据集中的节点被分为了七个主题类，我们的模型将会训练来预测出版物顶点的主题。\n",
    "\n",
    "在这一任务中，我们使用图聚合网络(GCN)算法来训练模型。有关这一算法的更多信息可以参考[\"Knowing Your Neighbours: Machine Learning on Graphs\"](https://medium.com/stellargraph/knowing-your-neighbours-machine-learning-on-graphs-9b7c3d0d5896)\n",
    "\n",
    "这一教程将会分为以下几个步骤：\n",
    "\n",
    "- 启动GraphScope的学习引擎，并将图关联到引擎上\n",
    "- 使用内置的GCN模型定义训练过程，并定义相关的超参\n",
    "- 开始训练\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install graphscope package if you are NOT in the Playground\n",
    "\n",
    "!pip3 install graphscope\n",
    "!pip3 uninstall -y importlib_metadata  # Address an module conflict issue on colab.google. Remove this line if you are not on colab."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import the graphscope module.\n",
    "\n",
    "import graphscope\n",
    "\n",
    "graphscope.set_option(show_log=False)  # enable logging"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load cora dataset\n",
    "\n",
    "from graphscope.dataset import load_cora\n",
    "\n",
    "graph = load_cora()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "然后，我们需要定义一个特征列表用于图的训练。训练特征集合必须从点的属性集合中选取。在这个例子中，我们选择了属性集合中所有以\"feat_\"为前缀的属性作为训练特征集，这一特征集也是Cora数据中点的特征集。\n",
    "\n",
    "借助定义的特征列表，接下来，我们使用 [graphlearn](https://graphscope.io/docs/reference/session.html#graphscope.Session.graphlearn) 方法来开启一个学习引擎。\n",
    "\n",
    "在这个例子中，我们在 \"graphlearn\" 方法中，指定在数据中 \"paper\" 类型的顶点和 \"cites\" 类型边上进行模型训练。\n",
    "\n",
    "利用 \"gen_labels\" 参数，我们将 \"paper\" 点数据集进行划分，其中75%作为训练集，10%作为验证集，15%作为测试集。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define the features for learning\n",
    "paper_features = []\n",
    "for i in range(1433):\n",
    "    paper_features.append(\"feat_\" + str(i))\n",
    "\n",
    "# launch a learning engine.\n",
    "lg = graphscope.graphlearn(\n",
    "    graph,\n",
    "    nodes=[(\"paper\", paper_features)],\n",
    "    edges=[(\"paper\", \"cites\", \"paper\")],\n",
    "    gen_labels=[\n",
    "        (\"train\", \"paper\", 100, (0, 75)),\n",
    "        (\"val\", \"paper\", 100, (75, 85)),\n",
    "        (\"test\", \"paper\", 100, (85, 100)),\n",
    "    ],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "这里我们使用内置的GCN模型定义训练过程。你可以在[Graph Learning Model](https://graphscope.io/docs/learning_engine.html#data-model)获取更多内置学习模型的信息。\n",
    "\n",
    "在本次示例中，我们使用tensorflow作为NN后端训练器。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import graphscope.learning\n",
    "from graphscope.learning.examples import GCN\n",
    "from graphscope.learning.graphlearn.python.model.tf.optimizer import get_tf_optimizer\n",
    "from graphscope.learning.graphlearn.python.model.tf.trainer import LocalTFTrainer\n",
    "\n",
    "# supervised GCN.\n",
    "\n",
    "\n",
    "def train(config, graph):\n",
    "    def model_fn():\n",
    "        return GCN(\n",
    "            graph,\n",
    "            config[\"class_num\"],\n",
    "            config[\"features_num\"],\n",
    "            config[\"batch_size\"],\n",
    "            val_batch_size=config[\"val_batch_size\"],\n",
    "            test_batch_size=config[\"test_batch_size\"],\n",
    "            categorical_attrs_desc=config[\"categorical_attrs_desc\"],\n",
    "            hidden_dim=config[\"hidden_dim\"],\n",
    "            in_drop_rate=config[\"in_drop_rate\"],\n",
    "            neighs_num=config[\"neighs_num\"],\n",
    "            hops_num=config[\"hops_num\"],\n",
    "            node_type=config[\"node_type\"],\n",
    "            edge_type=config[\"edge_type\"],\n",
    "            full_graph_mode=config[\"full_graph_mode\"],\n",
    "        )\n",
    "\n",
    "    graphscope.learning.reset_default_tf_graph()\n",
    "    trainer = LocalTFTrainer(\n",
    "        model_fn,\n",
    "        epoch=config[\"epoch\"],\n",
    "        optimizer=get_tf_optimizer(\n",
    "            config[\"learning_algo\"], config[\"learning_rate\"], config[\"weight_decay\"]\n",
    "        ),\n",
    "    )\n",
    "    trainer.train_and_evaluate()\n",
    "\n",
    "\n",
    "# define hyperparameters\n",
    "config = {\n",
    "    \"class_num\": 7,  # output dimension\n",
    "    \"features_num\": 1433,\n",
    "    \"batch_size\": 140,\n",
    "    \"val_batch_size\": 300,\n",
    "    \"test_batch_size\": 1000,\n",
    "    \"categorical_attrs_desc\": \"\",\n",
    "    \"hidden_dim\": 128,\n",
    "    \"in_drop_rate\": 0.5,\n",
    "    \"hops_num\": 2,\n",
    "    \"neighs_num\": [5, 5],\n",
    "    \"full_graph_mode\": False,\n",
    "    \"agg_type\": \"gcn\",  # mean, sum\n",
    "    \"learning_algo\": \"adam\",\n",
    "    \"learning_rate\": 0.01,\n",
    "    \"weight_decay\": 0.0005,\n",
    "    \"epoch\": 5,\n",
    "    \"node_type\": \"paper\",\n",
    "    \"edge_type\": \"cites\",\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在定义完训练过程和超参后，现在我们可以使用学习引擎和定义的超参开始训练过程。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train(config, lg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
